#!/usr/bin/env python2

"""
    thermalctld
    Thermal control daemon for SONiC
"""

try:
    import os
    import sys
    import time
    import signal
    import threading
    from datetime import datetime
    from sonic_daemon_base import daemon_base
    from sonic_daemon_base.daemon_base import Logger
    from sonic_daemon_base.daemon_base import DaemonBase
    from sonic_daemon_base.task_base import ProcessTaskBase
except ImportError as e:
    raise ImportError(str(e) + " - required module not found")

try:
    from swsscommon import swsscommon
except ImportError as e:
    from tests import mock_swsscommon as swsscommon

# Identifier handed to the Logger created below
SYSLOG_IDENTIFIER = 'thermalctld'
# Placeholder used throughout this file when a value cannot be obtained
NOT_AVAILABLE = 'N/A'
# Module-level logger shared by every function and class in this file
logger = Logger(SYSLOG_IDENTIFIER)


# utility functions

# Try to get information from the platform API; return a default value if the call raises NotImplementedError or returns None
def try_get(callback, default=NOT_AVAILABLE):
    """
    Call a platform API getter and fall back to a default when it has no answer
    :param callback: Zero-argument callable to invoke
    :param default: Value returned when the callable raises NotImplementedError or returns None
    :return: Result of the callable, or the default value in the two cases above
    """
    try:
        value = callback()
    except NotImplementedError:
        # Only NotImplementedError is absorbed; any other exception reaches the caller
        return default

    return default if value is None else value


def log_on_status_changed(normal_status, normal_log, abnormal_log):
    """
    Emit one of two log messages depending on whether the status is normal
    :param normal_status: Truthy when the status is the expected one.
    :param normal_log: Message logged at notice level for the expected status.
    :param abnormal_log: Message logged at warning level for the unexpected status.
    :return:
    """
    if not normal_status:
        logger.log_warning(abnormal_log)
        return

    logger.log_notice(normal_log)


class FanStatus(object):
    """
    Cached status flags of one fan, plus class-wide counters of absent and
    faulty fans that are accumulated by the set_* calls and cleared by
    reset_fan_counter().
    """
    # Number of absent non-PSU fans reported since the last reset_fan_counter()
    absence_fan_count = 0
    # Number of faulty fans reported since the last reset_fan_counter()
    fault_fan_count = 0

    def __init__(self, fan=None, is_psu_fan=False):
        """
        Constructor of FanStatus
        :param fan: Object representing a platform Fan
        :param is_psu_fan: True if the Fan belongs to a PSU
        """
        self.fan = fan
        self.is_psu_fan = is_psu_fan
        self.presence = True
        self.status = True
        self.under_speed = False
        self.over_speed = False
        self.invalid_direction = False

    @classmethod
    def get_bad_fan_count(cls):
        """
        Get the number of bad fans counted since the last reset
        :return: Sum of the absent fan counter and the faulty fan counter
        """
        return cls.absence_fan_count + cls.fault_fan_count

    @classmethod
    def reset_fan_counter(cls):
        """
        Reset the absent and faulty fan counters to zero
        :return:
        """
        cls.absence_fan_count = 0
        cls.fault_fan_count = 0

    def set_presence(self, presence):
        """
        Set and cache Fan presence status
        :param presence: Fan presence status
        :return: True if status changed else False
        """
        # The counter is bumped on every call, not only on a change; PSU fans are not counted
        if not presence and not self.is_psu_fan:
            FanStatus.absence_fan_count += 1

        if presence == self.presence:
            return False

        self.presence = presence
        return True

    def set_fault_status(self, status):
        """
        Set and cache Fan fault status
        :param status: Fan fault status, False indicate Fault
        :return: True if status changed else False
        """
        # The counter is bumped on every call, not only on a change
        if not status:
            FanStatus.fault_fan_count += 1

        if status == self.status:
            return False

        self.status = status
        return True

    @staticmethod
    def _is_valid_tolerance(tolerance):
        """
        Check that a speed tolerance is a percentage in the range [0, 100]
        :param tolerance: Tolerance value to check
        :return: True if 0 <= tolerance <= 100, False otherwise (including values
                 that cannot be compared with a number)
        """
        try:
            return bool(0 <= tolerance <= 100)
        except TypeError:
            # Python 3 refuses to order e.g. a str against an int
            return False

    def _check_speed_value_available(self, speed, target_speed, tolerance, current_status):
        """
        Check whether the speed values can be used for an under/over speed comparison
        :param speed: Fan speed
        :param target_speed: Fan target speed
        :param tolerance: Threshold between Fan speed and target speed
        :param current_status: Currently cached under/over speed flag; the "unavailable"
                               warning is logged only when it is True
        :return: True if all values are available and tolerance is valid, else False
        """
        if speed == NOT_AVAILABLE or target_speed == NOT_AVAILABLE or tolerance == NOT_AVAILABLE:
            if current_status is True:
                logger.log_warning('Fan speed or target_speed or tolerance become unavailable, '
                                   'speed={}, target_speed={}, tolerance={}'.format(speed, target_speed, tolerance))
            return False

        # Range validation must run when the values ARE available, otherwise an
        # out-of-range tolerance would silently feed the speed comparison.
        if not self._is_valid_tolerance(tolerance):
            logger.log_warning('Invalid tolerance value: {}'.format(tolerance))
            return False

        return True

    def set_under_speed(self, speed, target_speed, tolerance):
        """
        Set and cache Fan under speed status
        :param speed: Fan speed
        :param target_speed: Fan target speed
        :param tolerance: Threshold between Fan speed and target speed
        :return: True if status changed else False
        """
        if not self._check_speed_value_available(speed, target_speed, tolerance, self.under_speed):
            # Without usable values the flag is cleared rather than left stale
            old_status = self.under_speed
            self.under_speed = False
            return old_status != self.under_speed

        status = speed < target_speed * (1 - float(tolerance) / 100)
        if status == self.under_speed:
            return False

        self.under_speed = status
        return True

    def set_over_speed(self, speed, target_speed, tolerance):
        """
        Set and cache Fan over speed status
        :param speed: Fan speed
        :param target_speed: Fan target speed
        :param tolerance: Threshold between Fan speed and target speed
        :return: True if status changed else False
        """
        if not self._check_speed_value_available(speed, target_speed, tolerance, self.over_speed):
            # Without usable values the flag is cleared rather than left stale
            old_status = self.over_speed
            self.over_speed = False
            return old_status != self.over_speed

        status = speed > target_speed * (1 + float(tolerance) / 100)
        if status == self.over_speed:
            return False

        self.over_speed = status
        return True

    def is_ok(self):
        """
        Indicate the Fan works as expect
        :return: True if Fan works normal else False
        """
        return self.presence and \
               self.status and \
               not self.under_speed and \
               not self.over_speed and \
               not self.invalid_direction


#
# FanUpdater ===================================================================
#
class FanUpdater(object):
    """
    Reads fan status through the platform chassis object and writes it to the
    FAN_INFO table obtained from the STATE_DB connection.
    """
    # Fan information table name in database
    FAN_INFO_TABLE_NAME = 'FAN_INFO'

    def __init__(self, chassis):
        """
        Constructor for FanUpdater
        :param chassis: Object representing a platform chassis
        """
        self.chassis = chassis
        # Fan name -> FanStatus, filled lazily by _refresh_fan_status
        self.fan_status_dict = {}
        self.table = swsscommon.Table(daemon_base.db_connect("STATE_DB"),
                                      FanUpdater.FAN_INFO_TABLE_NAME)

    def deinit(self):
        """
        Remove the table entry of every fan this updater has seen
        :return:
        """
        for fan_name in list(self.fan_status_dict):
            self.table._del(fan_name)

    def update(self):
        """
        Refresh every chassis fan and PSU fan, publish the LED state and log
        when the number of bad fans changed since the previous call
        :return:
        """
        logger.log_debug("Start fan updating")
        previous_bad = FanStatus.get_bad_fan_count()
        # Counters are rebuilt from scratch by the set_* calls made during the refresh below
        FanStatus.reset_fan_counter()

        # Chassis fans are numbered continuously across all drawers
        chassis_fan_pos = 0
        for fan_drawer in self.chassis.get_all_fan_drawers():
            for drawer_fan in fan_drawer.get_all_fans():
                try:
                    self._refresh_fan_status(fan_drawer, drawer_fan, chassis_fan_pos)
                except Exception as err:
                    logger.log_warning('Failed to update FAN status - {}'.format(err))
                chassis_fan_pos += 1

        # PSU fans are numbered per PSU and carry the PSU name as prefix
        for psu_pos, psu in enumerate(self.chassis.get_all_psus()):
            psu_label = try_get(psu.get_name, 'PSU {}'.format(psu_pos))
            for psu_fan_pos, psu_fan in enumerate(psu.get_all_fans()):
                try:
                    self._refresh_fan_status(None, psu_fan, psu_fan_pos,
                                             name_prefix='{} FAN'.format(psu_label),
                                             is_psu_fan=True)
                except Exception as err:
                    logger.log_warning('Failed to update PSU FAN status - {}'.format(err))

        self._update_led_color()

        current_bad = FanStatus.get_bad_fan_count()
        if current_bad > 0 and previous_bad != current_bad:
            logger.log_warning(
                "Insufficient number of working fans warning: {} fans are not working.".format(current_bad))
        elif previous_bad > 0 and current_bad == 0:
            logger.log_notice("Insufficient number of working fans warning cleared: all fans are back to normal.")

        logger.log_debug("End fan updating")

    def _refresh_fan_status(self, fan_drawer, fan, index, name_prefix='FAN', is_psu_fan=False):
        """
        Query one Fan through the platform API, log status transitions, drive
        its LED when needed and write its fields to the database
        :param fan_drawer: Object representing a platform Fan drawer (not used for the name when is_psu_fan is True)
        :param fan: Object representing a platform Fan
        :param index: Zero-based index of the Fan, used to build a fallback name
        :param name_prefix: Prefix of the fallback name when Fan.get_name gives nothing
        :param is_psu_fan: True if the Fan belongs to a PSU
        :return:
        """
        if is_psu_fan:
            drawer_name = NOT_AVAILABLE
        else:
            drawer_name = str(try_get(fan_drawer.get_name))

        fan_name = try_get(fan.get_name, '{} {}'.format(name_prefix, index + 1))
        try:
            fan_status = self.fan_status_dict[fan_name]
        except KeyError:
            fan_status = self.fan_status_dict[fan_name] = FanStatus(fan, is_psu_fan)

        # The remaining getters are queried only for a fan that reports presence
        presence = try_get(fan.get_presence, False)
        if presence:
            speed = try_get(fan.get_speed)
            speed_tolerance = try_get(fan.get_speed_tolerance)
            speed_target = try_get(fan.get_target_speed)
            fan_fault_status = try_get(fan.get_status, False)
            fan_direction = try_get(fan.get_direction)
        else:
            speed = speed_tolerance = speed_target = NOT_AVAILABLE
            fan_fault_status = fan_direction = NOT_AVAILABLE

        led_needs_update = False
        if fan_status.set_presence(presence):
            led_needs_update = True
            log_on_status_changed(fan_status.presence,
                                  'Fan removed warning cleared: {} was inserted.'.format(fan_name),
                                  'Fan removed warning: {} was removed from '
                                  'the system, potential overheat hazard'.format(fan_name)
                                  )

        if presence:
            if fan_status.set_fault_status(fan_fault_status):
                led_needs_update = True
                log_on_status_changed(fan_status.status,
                                      'Fan fault warning cleared: {} is back to normal.'.format(fan_name),
                                      'Fan fault warning: {} is broken.'.format(fan_name)
                                      )

            if fan_status.set_under_speed(speed, speed_target, speed_tolerance):
                led_needs_update = True
                log_on_status_changed(not fan_status.under_speed,
                                      'Fan low speed warning cleared: {} speed is back to normal.'.format(fan_name),
                                      'Fan low speed warning: {} current speed={}, target speed={}, tolerance={}.'.
                                      format(fan_name, speed, speed_target, speed_tolerance)
                                      )

            if fan_status.set_over_speed(speed, speed_target, speed_tolerance):
                led_needs_update = True
                log_on_status_changed(not fan_status.over_speed,
                                      'Fan high speed warning cleared: {} speed is back to normal.'.format(fan_name),
                                      'Fan high speed warning: {} target speed={}, current speed={}, tolerance={}.'.
                                      format(fan_name, speed_target, speed, speed_tolerance)
                                      )

        # TODO: handle invalid fan direction

        # We don't set PSU led here, PSU led will be handled in psud
        if led_needs_update and not is_psu_fan:
            self._set_fan_led(fan_drawer, fan, fan_name, fan_status)

        fields = [
            ('presence', str(presence)),
            ('drawer_name', drawer_name),
            ('model', str(try_get(fan.get_model))),
            ('serial', str(try_get(fan.get_serial))),
            ('status', str(fan_fault_status)),
            ('direction', str(fan_direction)),
            ('speed', str(speed)),
            ('speed_tolerance', str(speed_tolerance)),
            ('speed_target', str(speed_target)),
            ('timestamp', datetime.now().strftime('%Y%m%d %H:%M:%S')),
        ]
        self.table.set(fan_name, swsscommon.FieldValuePairs(fields))

    def _set_fan_led(self, fan_drawer, fan, fan_name, fan_status):
        """
        Set the LED of a Fan and of its drawer: green when the Fan is ok, red otherwise
        :param fan_drawer: Object representing a platform Fan drawer
        :param fan: Object representing a platform Fan
        :param fan_name: Name of the Fan object (not used by this method)
        :param fan_status: Object representing the FanStatus
        :return:
        """
        try:
            if fan_status.is_ok():
                led_color = fan.STATUS_LED_COLOR_GREEN
            else:
                # TODO: wait for Kebo to define the mapping of fan status to led color,
                # just set it to red so far
                led_color = fan.STATUS_LED_COLOR_RED
            fan.set_status_led(led_color)
            fan_drawer.set_status_led(led_color)
        except NotImplementedError:
            logger.log_warning('Failed to set led to fan, set_status_led not implemented')

    def _update_led_color(self):
        """
        Write the current LED status of every known Fan to the database,
        using the not-available placeholder when it cannot be read
        :return:
        """
        for fan_name, fan_status in self.fan_status_dict.items():
            try:
                led_state = str(try_get(fan_status.fan.get_status_led))
                led_fields = swsscommon.FieldValuePairs([('led_status', led_state)])
            except Exception:
                logger.log_warning('Failed to get led status for fan')
                led_fields = swsscommon.FieldValuePairs([('led_status', NOT_AVAILABLE)])
            self.table.set(fan_name, led_fields)


class TemperatureStatus(object):
    """
    Cached temperature reading and over/under temperature flags of one thermal.
    """
    # A warning is logged when two consecutive readings differ by more than this value
    TEMPERATURE_DIFF_THRESHOLD = 10

    def __init__(self):
        """
        Constructor of TemperatureStatus
        """
        self.temperature = None
        self.over_temperature = False
        self.under_temperature = False

    def set_temperature(self, name, temperature):
        """
        Cache a new temperature reading and warn when it jumped too far from the previous one.
        :param name: Name of the thermal.
        :param temperature: New temperature value.
        :return:
        """
        previous = self.temperature

        if temperature == NOT_AVAILABLE:
            # Warn once, on the transition from a known reading to an unavailable one
            if previous is not None:
                logger.log_warning('Temperature of {} become unavailable'.format(name))
                self.temperature = None
            return

        if previous is not None:
            if abs(temperature - previous) > TemperatureStatus.TEMPERATURE_DIFF_THRESHOLD:
                logger.log_warning(
                    'Temperature of {} changed too fast, from {} to {}, please check your hardware'.format(
                        name, previous, temperature))

        self.temperature = temperature

    def _check_temperature_value_available(self, temperature, threshold, current_status):
        """
        Check whether both temperature and threshold can be used for a comparison
        :param temperature: Temperature
        :param threshold: Threshold to compare against
        :param current_status: Currently cached flag; a warning is logged only when it is True
        :return: True if both values are available else False
        """
        if temperature != NOT_AVAILABLE and threshold != NOT_AVAILABLE:
            return True

        if current_status is True:
            logger.log_warning('Thermal temperature or threshold become unavailable, '
                               'temperature={}, threshold={}'.format(temperature, threshold))
        return False

    def set_over_temperature(self, temperature, threshold):
        """
        Set over temperature status
        :param temperature: Temperature
        :param threshold: High threshold
        :return: True if over temperature status changed else False
        """
        previous = self.over_temperature
        if self._check_temperature_value_available(temperature, threshold, previous):
            current = temperature > threshold
        else:
            # Without usable values the flag is cleared
            current = False

        self.over_temperature = current
        return current != previous

    def set_under_temperature(self, temperature, threshold):
        """
        Set under temperature status
        :param temperature: Temperature
        :param threshold: Low threshold
        :return: True if under temperature status changed else False
        """
        previous = self.under_temperature
        if self._check_temperature_value_available(temperature, threshold, previous):
            current = temperature < threshold
        else:
            # Without usable values the flag is cleared
            current = False

        self.under_temperature = current
        return current != previous


#
# TemperatureUpdater  ======================================================================
#
class TemperatureUpdater(object):
    """
    Reads thermal status through the platform chassis object and writes it to
    the TEMPERATURE_INFO table obtained from the STATE_DB connection.
    """
    # Temperature information table name in database
    TEMPER_INFO_TABLE_NAME = 'TEMPERATURE_INFO'

    def __init__(self, chassis):
        """
        Constructor of TemperatureUpdater
        :param chassis: Object representing a platform chassis
        """
        self.chassis = chassis
        # Thermal name -> TemperatureStatus, filled lazily by _refresh_temperature_status
        self.temperature_status_dict = {}
        self.table = swsscommon.Table(daemon_base.db_connect("STATE_DB"),
                                      TemperatureUpdater.TEMPER_INFO_TABLE_NAME)

    def deinit(self):
        """
        Remove the table entry of every thermal this updater has seen
        :return:
        """
        for thermal_name in list(self.temperature_status_dict):
            self.table._del(thermal_name)

    def update(self):
        """
        Refresh every thermal of the chassis; a failure on one thermal is
        logged and does not stop the others
        :return:
        """
        logger.log_debug("Start temperature updating")
        for position, thermal in enumerate(self.chassis.get_all_thermals()):
            try:
                self._refresh_temperature_status(thermal, position)
            except Exception as err:
                logger.log_warning('Failed to update thermal status - {}'.format(err))

        logger.log_debug("End temperature updating")

    def _refresh_temperature_status(self, thermal, index):
        """
        Query one thermal through the platform API, log threshold crossings
        and write its fields to the database
        :param thermal: Object representing a platform thermal zone
        :param index: Zero-based index of the thermal, used to build a fallback name
        :return:
        """
        name = try_get(thermal.get_name, 'Thermal {}'.format(index + 1))
        try:
            temperature_status = self.temperature_status_dict[name]
        except KeyError:
            temperature_status = self.temperature_status_dict[name] = TemperatureStatus()

        temperature = try_get(thermal.get_temperature)
        if temperature == NOT_AVAILABLE:
            # Thresholds are not queried and the cached status flags are left
            # untouched while the temperature is unavailable
            high_threshold = low_threshold = NOT_AVAILABLE
            high_critical_threshold = low_critical_threshold = NOT_AVAILABLE
        else:
            temperature_status.set_temperature(name, temperature)
            high_threshold = try_get(thermal.get_high_threshold)
            low_threshold = try_get(thermal.get_low_threshold)
            high_critical_threshold = try_get(thermal.get_high_critical_threshold)
            low_critical_threshold = try_get(thermal.get_low_critical_threshold)

            if temperature_status.set_over_temperature(temperature, high_threshold):
                log_on_status_changed(not temperature_status.over_temperature,
                                      'High temperature warning cleared: {} temperature restore to {}C, high threshold {}C.'.
                                      format(name, temperature, high_threshold),
                                      'High temperature warning: {} current temperature {}C, high threshold {}C'.
                                      format(name, temperature, high_threshold)
                                      )

            if temperature_status.set_under_temperature(temperature, low_threshold):
                log_on_status_changed(not temperature_status.under_temperature,
                                      'Low temperature warning cleared: {} temperature restore to {}C, low threshold {}C.'.
                                      format(name, temperature, low_threshold),
                                      'Low temperature warning: {} current temperature {}C, low threshold {}C'.
                                      format(name, temperature, low_threshold)
                                      )

        # The published warning is raised when either cached flag is set
        warning = temperature_status.over_temperature or temperature_status.under_temperature

        fields = [
            ('temperature', str(temperature)),
            ('high_threshold', str(high_threshold)),
            ('low_threshold', str(low_threshold)),
            ('warning_status', str(warning)),
            ('critical_high_threshold', str(high_critical_threshold)),
            ('critical_low_threshold', str(low_critical_threshold)),
            ('timestamp', datetime.now().strftime('%Y%m%d %H:%M:%S')),
        ]
        self.table.set(name, swsscommon.FieldValuePairs(fields))


class ThermalMonitor(ProcessTaskBase):
    """
    Periodic task that refreshes fan and temperature information.
    """
    # Initial update interval
    INITIAL_INTERVAL = 5
    # Update interval value
    UPDATE_INTERVAL = 60
    # Update elapse threshold. If update used time is larger than the value, generate a warning log.
    UPDATE_ELAPSE_THRESHOLD = 30

    def __init__(self, chassis):
        """
        Constructor for ThermalMonitor
        :param chassis: Object representing a platform chassis
        """
        ProcessTaskBase.__init__(self)
        self.fan_updater = FanUpdater(chassis)
        self.temperature_updater = TemperatureUpdater(chassis)

    def task_worker(self):
        """
        Worker loop: update fan status, then temperature status, until the
        stopping event is set; afterwards deinitialize both updaters
        :return:
        """
        logger.log_info("Start thermal monitoring loop")

        # Start loop to update fan, temperature info in DB periodically
        delay = ThermalMonitor.INITIAL_INTERVAL
        while True:
            # wait() doubles as the sleep between rounds and as the stop check
            if self.task_stopping_event.wait(delay):
                break

            started_at = time.time()
            for updater in (self.fan_updater, self.temperature_updater):
                updater.update()
            duration = time.time() - started_at

            # Keep rounds UPDATE_INTERVAL apart; after an overlong round retry after INITIAL_INTERVAL
            if duration < ThermalMonitor.UPDATE_INTERVAL:
                delay = ThermalMonitor.UPDATE_INTERVAL - duration
            else:
                delay = ThermalMonitor.INITIAL_INTERVAL

            if duration > ThermalMonitor.UPDATE_ELAPSE_THRESHOLD:
                logger.log_warning('Update fan and temperature status takes {} seconds, '
                                   'there might be performance risk'.format(duration))

        for updater in (self.fan_updater, self.temperature_updater):
            updater.deinit()

        logger.log_info("Stop thermal monitoring loop")


#
# Daemon =======================================================================
#
class ThermalControlDaemon(DaemonBase):
    """
    Daemon that starts the thermal monitor task and periodically runs the
    thermal policy of the chassis thermal manager, if the chassis provides one.
    """
    # Interval to run thermal control logic
    INTERVAL = 60
    # Policy file handed to the thermal manager's load() in run()
    POLICY_FILE = '/usr/share/sonic/platform/thermal_policy.json'

    def __init__(self):
        """
        Constructor of ThermalControlDaemon
        """
        DaemonBase.__init__(self)
        # Set by signal_handler on SIGINT/SIGTERM to end the loop in run()
        self.stop_event = threading.Event()

    # Signal handler
    def signal_handler(self, sig, frame):
        """
        Signal handler
        NOTE(review): registration of this handler is not visible here; presumably
        done by DaemonBase - confirm.
        :param sig: Signal number
        :param frame: not used
        :return:
        """
        if sig == signal.SIGHUP:
            logger.log_info("Caught SIGHUP - ignoring...")
        elif sig == signal.SIGINT:
            logger.log_info("Caught SIGINT - exiting...")
            self.stop_event.set()
        elif sig == signal.SIGTERM:
            logger.log_info("Caught SIGTERM - exiting...")
            self.stop_event.set()
        else:
            # sig is an integer, so it must be formatted; concatenating it to a
            # str would raise TypeError inside the handler
            logger.log_warning("Caught unhandled signal '{}'".format(sig))

    def run(self):
        """
        Run main logical of this daemon: start the monitor task, initialize the
        thermal manager, run the thermal policy every INTERVAL seconds until the
        stop event is set, then deinitialize the manager and stop the monitor task
        :return:
        """
        logger.log_info("Starting up...")

        import sonic_platform.platform
        chassis = sonic_platform.platform.Platform().get_chassis()

        thermal_monitor = ThermalMonitor(chassis)
        thermal_monitor.task_run()

        # A failure while setting up the thermal manager is logged and the daemon keeps running
        thermal_manager = None
        try:
            thermal_manager = chassis.get_thermal_manager()
            if thermal_manager:
                thermal_manager.initialize()
                thermal_manager.load(ThermalControlDaemon.POLICY_FILE)
                thermal_manager.init_thermal_algorithm(chassis)
        except Exception as e:
            logger.log_error('Caught exception while initializing thermal manager - {}'.format(e))

        # wait() returns True once stop_event is set, which ends the loop
        while not self.stop_event.wait(ThermalControlDaemon.INTERVAL):
            try:
                if thermal_manager:
                    thermal_manager.run_policy(chassis)
            except Exception as e:
                logger.log_error('Caught exception while running thermal policy - {}'.format(e))

        try:
            if thermal_manager:
                thermal_manager.deinitialize()
        except Exception as e:
            logger.log_error('Caught exception while destroy thermal manager - {}'.format(e))

        thermal_monitor.task_stop()

        logger.log_info("Shutdown...")


#
# Main =========================================================================
#
def main():
    """
    Create a ThermalControlDaemon and call its run() method
    :return:
    """
    ThermalControlDaemon().run()


# Start the daemon only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()
